{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "variables_and_state.ipynb",
      "provenance": [],
      "collapsed_sections": [
        "FH3IRpYTta2v"
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FH3IRpYTta2v"
      },
      "source": [
        "##### Copyright 2021 The IREE Authors"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mWGa71_Ct2ug",
        "cellView": "form"
      },
      "source": [
        "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n",
        "# See https://llvm.org/LICENSE.txt for license information.\n",
        "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h5s6ncerSpc5"
      },
      "source": [
        "# Variables and State\n",
        "\n",
        "This notebook:\n",
        "\n",
        "1. Creates a TensorFlow program with basic tf.Variable use\n",
        "2. Imports that program into IREE's compiler\n",
        "3. Compiles the imported program to an IREE VM bytecode module\n",
        "4. Tests running the compiled VM module using IREE's runtime\n",
        "5. Downloads compilation artifacts for use with the native (C API) sample application"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "s2bScbYkP6VZ",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "a0d86566-9ae9-4b2a-bbf0-893ff72980d1"
      },
      "source": [
        "#@title General setup\n",
        "\n",
        "import os\n",
        "import tempfile\n",
        "\n",
        "ARTIFACTS_DIR = os.path.join(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n",
        "os.makedirs(ARTIFACTS_DIR, exist_ok=True)\n",
        "print(f\"Using artifacts directory '{ARTIFACTS_DIR}'\")"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Using artifacts directory '/tmp/iree/colab_artifacts'\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dBHgjTjGPOJ7"
      },
      "source": [
        "## Create a program using TensorFlow and import it into IREE\n",
        "\n",
        "This program uses `tf.Variable` to track state internal to the program, then exports functions that can be used to interact with that variable.\n",
        "\n",
        "Note that each function we want to be callable from our compiled program needs\n",
        "to use `@tf.function` with an `input_signature` specified.\n",
        "\n",
        "References:\n",
        "\n",
        "* [\"Introduction to Variables\" Guide](https://www.tensorflow.org/guide/variable)\n",
        "* [`tf.Variable` reference](https://www.tensorflow.org/api_docs/python/tf/Variable)\n",
        "* [`tf.function` reference](https://www.tensorflow.org/api_docs/python/tf/function)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hwApbPstraWZ"
      },
      "source": [
        "#@title Define a simple \"counter\" TensorFlow module\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "class CounterModule(tf.Module):\n",
        "  def __init__(self):\n",
        "    super().__init__()\n",
        "    self.counter = tf.Variable(0)\n",
        "\n",
        "  @tf.function(input_signature=[])\n",
        "  def get_value(self):\n",
        "    return self.counter\n",
        "\n",
        "  @tf.function(input_signature=[tf.TensorSpec([], tf.int32)])\n",
        "  def set_value(self, new_value):\n",
        "    self.counter.assign(new_value)\n",
        "\n",
        "  @tf.function(input_signature=[tf.TensorSpec([], tf.int32)])\n",
        "  def add_to_value(self, x):\n",
        "    self.counter.assign(self.counter + x)\n",
        "\n",
        "  @tf.function(input_signature=[])\n",
        "  def reset_value(self):\n",
        "    self.set_value(0)"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "k4aMPI2C7btB"
      },
      "source": [
        "%%capture\n",
        "!python -m pip install iree-compiler-snapshot iree-tools-tf-snapshot -f https://github.com/google/iree/releases"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3nSXZiZ_X8-P",
        "outputId": "fbd021ef-3da9-475d-87b9-00e0e8139bf3"
      },
      "source": [
        "#@title Import the TensorFlow program into IREE as MLIR\n",
        "\n",
        "from IPython.display import clear_output\n",
        "\n",
        "from iree.compiler import tf as tfc\n",
        "\n",
        "compiler_module = tfc.compile_module(\n",
        "    CounterModule(), import_only=True, output_mlir_debuginfo=False)\n",
        "clear_output()  # Skip over TensorFlow's output.\n",
        "\n",
        "# Print the imported MLIR to see how the compiler views this TensorFlow program.\n",
        "# Note IREE's `util.global` ops and the public (exported) functions.\n",
        "print(\"Counter MLIR:\\n```\\n%s```\\n\" % compiler_module.decode(\"utf-8\"))\n",
        "\n",
        "# Save the imported MLIR to disk.\n",
        "imported_mlir_path = os.path.join(ARTIFACTS_DIR, \"counter.mlir\")\n",
        "with open(imported_mlir_path, \"wt\") as output_file:\n",
        "  output_file.write(compiler_module.decode(\"utf-8\"))\n",
        "print(f\"Wrote MLIR to path '{imported_mlir_path}'\")"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Counter MLIR:\n",
            "```\n",
            "#map = affine_map<() -> ()>\n",
            "module  {\n",
            "  util.global private mutable @counter = dense<0> : tensor<i32>\n",
            "  func @add_to_value(%arg0: !hal.buffer_view) attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22r\\22:[],\\22v\\22:1}\"} {\n",
            "    %0 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<i32>\n",
            "    call @__inference_add_to_value_100(%0) : (tensor<i32>) -> ()\n",
            "    return\n",
            "  }\n",
            "  func private @__inference_add_to_value_100(%arg0: tensor<i32> {tf._user_specified_name = \"x\"}) attributes {tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {\n",
            "    %0 = util.global.load @counter : tensor<i32>\n",
            "    %1 = linalg.init_tensor [] : tensor<i32>\n",
            "    %2 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%0, %arg0 : tensor<i32>, tensor<i32>) outs(%1 : tensor<i32>) {\n",
            "    ^bb0(%arg1: i32, %arg2: i32, %arg3: i32):  // no predecessors\n",
            "      %3 = arith.addi %arg1, %arg2 : i32\n",
            "      linalg.yield %3 : i32\n",
            "    } -> tensor<i32>\n",
            "    util.global.store %2, @counter : tensor<i32>\n",
            "    return\n",
            "  }\n",
            "  func @get_value() -> !hal.buffer_view attributes {iree.abi = \"{\\22a\\22:[],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22v\\22:1}\"} {\n",
            "    %0 = util.global.load @counter : tensor<i32>\n",
            "    %1 = hal.tensor.cast %0 : tensor<i32> -> !hal.buffer_view\n",
            "    return %1 : !hal.buffer_view\n",
            "  }\n",
            "  func @reset_value() attributes {iree.abi = \"{\\22a\\22:[],\\22r\\22:[],\\22v\\22:1}\"} {\n",
            "    %cst = arith.constant dense<0> : tensor<i32>\n",
            "    util.global.store %cst, @counter : tensor<i32>\n",
            "    return\n",
            "  }\n",
            "  func @set_value(%arg0: !hal.buffer_view) attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22r\\22:[],\\22v\\22:1}\"} {\n",
            "    %0 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<i32>\n",
            "    util.global.store %0, @counter : tensor<i32>\n",
            "    return\n",
            "  }\n",
            "}\n",
            "\n",
            "```\n",
            "\n",
            "Wrote MLIR to path '/tmp/iree/colab_artifacts/counter.mlir'\n"
          ],
          "name": "stdout"
        }
      ]
    },
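    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `iree.abi` attributes in the MLIR printed above look opaque because MLIR string literals write each double quote as the hex escape `\\22`. The following is a minimal sketch, using only the Python standard library, that turns one of those attributes back into JSON. The reading of the keys (`a` for arguments, `r` for results, `v` for version) is an assumption based on the function signatures shown above, not something this notebook states.\n",
        "\n",
        "```python\n",
        "import json\n",
        "\n",
        "# Attribute text copied from the printed MLIR for @add_to_value.\n",
        "# MLIR string literals write a double quote as the hex escape \\22.\n",
        "raw_abi = r'{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22r\\22:[],\\22v\\22:1}'\n",
        "\n",
        "# Undo the escaping, then parse the result as ordinary JSON.\n",
        "abi = json.loads(raw_abi.replace(r'\\22', '\"'))\n",
        "\n",
        "# Assumed meaning of the keys: a = arguments, r = results, v = version.\n",
        "print(abi['a'])  # [['ndarray', 'i32', 0]]\n",
        "print(abi['r'])  # []\n",
        "print(abi['v'])  # 1\n",
        "```\n",
        "\n",
        "For `add_to_value` this shows a single `i32` ndarray argument and no results, which is consistent with the `tf.TensorSpec([], tf.int32)` input signature in `CounterModule`."
      ]
    },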
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WCiRV6KRh3iA"
      },
      "source": [
        "## Test the imported program\n",
        "\n",
        "_Note: you can stop after each step and use intermediate outputs with other tools outside of Colab._\n",
        "\n",
        "_See the [README](https://github.com/google/iree/tree/main/iree/samples/variables_and_state#changing-compilation-options) for more details and example command line instructions._\n",
        "\n",
        "* _The \"imported MLIR\" can be used by IREE's generic compiler tools_\n",
        "* _The \"flatbuffer blob\" can be saved and used by runtime applications_\n",
        "\n",
        "_The specific point at which you switch from Python to native tools will depend on your project._"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6TV6_Hdu6Xlf"
      },
      "source": [
        "%%capture\n",
        "!python -m pip install iree-compiler-snapshot -f https://github.com/google/iree/releases"
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GF0dzDsbaP2w",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "850d68d0-36e5-498c-fa83-b93e90d1adbc"
      },
      "source": [
        "#@title Compile the imported MLIR further into an IREE VM bytecode module\n",
        "\n",
        "from iree.compiler import compile_str\n",
        "\n",
        "flatbuffer_blob = compile_str(compiler_module, target_backends=[\"vmvx\"])\n",
        "\n",
        "# Save the compiled VM bytecode module (flatbuffer) to disk.\n",
        "flatbuffer_path = os.path.join(ARTIFACTS_DIR, \"counter_vmvx.vmfb\")\n",
        "with open(flatbuffer_path, \"wb\") as output_file:\n",
        "  output_file.write(flatbuffer_blob)\n",
        "print(f\"Wrote .vmfb to path '{flatbuffer_path}'\")"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Wrote .vmfb to path '/tmp/iree/colab_artifacts/counter_vmvx.vmfb'\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "G7g5eXYL6hWb"
      },
      "source": [
        "%%capture\n",
        "!python -m pip install iree-runtime-snapshot -f https://github.com/google/iree/releases"
      ],
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "h8cmF6nAfza0",
        "outputId": "62636478-965a-4b5c-c170-28a5c84f4af1"
      },
      "source": [
        "#@title Test running the compiled VM module using IREE's runtime\n",
        "\n",
        "from iree import runtime as ireert\n",
        "\n",
        "vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)\n",
        "config = ireert.Config(\"vmvx\")\n",
        "ctx = ireert.SystemContext(config=config)\n",
        "ctx.add_vm_module(vm_module)"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Created IREE driver vmvx: <iree.runtime.binding.HalDriver object at 0x7fb8d992d070>\n",
            "SystemContext driver=<iree.runtime.binding.HalDriver object at 0x7fb8d992d070>\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "CQffg1iQatkb",
        "outputId": "689e2314-6523-470d-9faf-93acc4362e56"
      },
      "source": [
        "# Our @tf.functions are accessible by name on the module named 'module'\n",
        "counter = ctx.modules.module\n",
        "\n",
        "print(counter.get_value())\n",
        "counter.set_value(101)\n",
        "print(counter.get_value())\n",
        "\n",
        "counter.add_to_value(20)\n",
        "print(counter.get_value())\n",
        "counter.add_to_value(-50)\n",
        "print(counter.get_value())\n",
        "\n",
        "counter.reset_value()\n",
        "print(counter.get_value())"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "0\n",
            "101\n",
            "121\n",
            "71\n",
            "0\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wCvwX1IEokm6"
      },
      "source": [
        "## Download compilation artifacts"
      ]
    },
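    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The next cell shells out to `zip`, which may not be installed outside of Colab. As a hedged alternative that assumes only the Python standard library, `shutil.make_archive` can build the same kind of archive and `zipfile` can list its members. The temporary directories and placeholder file contents in this sketch are hypothetical stand-ins for `ARTIFACTS_DIR` and the real artifacts.\n",
        "\n",
        "```python\n",
        "import os\n",
        "import shutil\n",
        "import tempfile\n",
        "import zipfile\n",
        "\n",
        "# Stand-in for ARTIFACTS_DIR, filled with placeholder files (hypothetical contents).\n",
        "artifacts_dir = tempfile.mkdtemp()\n",
        "for name in ('counter.mlir', 'counter_vmvx.vmfb'):\n",
        "  with open(os.path.join(artifacts_dir, name), 'wb') as f:\n",
        "    f.write(b'placeholder')\n",
        "\n",
        "# make_archive appends '.zip' to the base name and returns the archive path.\n",
        "zip_base = os.path.join(tempfile.mkdtemp(), 'variables_and_state_colab_artifacts')\n",
        "zip_path = shutil.make_archive(zip_base, 'zip', root_dir=artifacts_dir)\n",
        "\n",
        "# List the archive members to confirm both artifacts were captured.\n",
        "with zipfile.ZipFile(zip_path) as archive:\n",
        "  zipped_names = sorted(archive.namelist())\n",
        "print(zipped_names)\n",
        "```\n",
        "\n",
        "Listing the members before downloading is a cheap way to catch an empty or partial artifacts directory."
      ]
    },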
    {
      "cell_type": "code",
      "metadata": {
        "id": "bUaNUkS2ohRj"
      },
      "source": [
        "ARTIFACTS_ZIP = \"/tmp/variables_and_state_colab_artifacts.zip\"\n",
        "\n",
        "print(f\"Zipping '{ARTIFACTS_DIR}' to '{ARTIFACTS_ZIP}' for download...\")\n",
        "!cd {ARTIFACTS_DIR} && zip -r {ARTIFACTS_ZIP} .\n",
        "\n",
        "# Note: you can also download files using Colab's file explorer\n",
        "try:\n",
        "  from google.colab import files\n",
        "  print(\"Downloading the artifacts zip file...\")\n",
        "  files.download(ARTIFACTS_ZIP)\n",
        "except ImportError:\n",
        "  print(\"Missing google.colab Python package, can't download files\")"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}
